/* * Copyright (c) 2014 Villu Ruusmann * * This file is part of Openscoring * * Openscoring is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * Openscoring is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Openscoring. If not, see <http://www.gnu.org/licenses/>. */ package org.openscoring.service; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.Collection; import java.util.Date; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.inject.Inject; import javax.inject.Named; import javax.inject.Singleton; import javax.xml.bind.JAXBException; import javax.xml.bind.Marshaller; import javax.xml.bind.Unmarshaller; import javax.xml.bind.ValidationEvent; import javax.xml.bind.ValidationEventHandler; import javax.xml.transform.Result; import javax.xml.transform.Source; import javax.xml.transform.sax.SAXSource; import javax.xml.transform.stream.StreamResult; import javax.xml.validation.Schema; import com.google.common.base.Preconditions; import com.google.common.hash.Hashing; import com.google.common.hash.HashingInputStream; import com.google.common.io.CountingInputStream; import com.typesafe.config.Config; import org.dmg.pmml.PMML; import org.dmg.pmml.Visitor; import org.jpmml.evaluator.ModelEvaluator; import org.jpmml.evaluator.ModelEvaluatorFactory; import org.jpmml.model.ImportFilter; import org.jpmml.model.JAXBUtil; import 
org.jvnet.hk2.annotations.Service;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;
import org.xml.sax.XMLReader;
import org.xml.sax.helpers.XMLReaderFactory;

/**
 * Registry of deployed {@link Model} instances, keyed by string identifier.
 *
 * <p>
 * Models are held in a {@link ConcurrentHashMap}; the mutating operations
 * ({@link #put}, {@link #replace}, {@link #remove}) delegate to the corresponding
 * conditional {@link ConcurrentMap} operations.
 * The registry also converts between PMML documents and {@link Model} objects
 * via {@link #load(InputStream)} and {@link #store(Model, OutputStream)}.
 * </p>
 */
@Service
@Singleton
public class ModelRegistry {

	/** Visitor classes that are instantiated and applied to every loaded PMML document, in configuration order. */
	private final List<Class<? extends Visitor>> visitorClazzes = new ArrayList<>();

	/** Whether the unmarshaller is given the schema obtained from {@link JAXBUtil#getSchema()}. */
	private final boolean validate;

	private final ConcurrentMap<String, Model> models = new ConcurrentHashMap<>();

	/**
	 * Creates a registry from the <code>modelRegistry</code> section of the supplied configuration.
	 *
	 * <p>
	 * Reads the string list <code>visitorClasses</code> (fully-qualified class names) and
	 * the boolean <code>validate</code>.
	 * </p>
	 *
	 * @param config The <code>openscoring</code> configuration.
	 *
	 * @throws IllegalArgumentException If a configured visitor class cannot be loaded,
	 * or is not a subclass of {@link Visitor}.
	 */
	@Inject
	public ModelRegistry(@Named("openscoring") Config config){
		Config modelRegistryConfig = config.getConfig("modelRegistry");

		List<String> visitorClassNames = modelRegistryConfig.getStringList("visitorClasses");
		for(String visitorClassName : visitorClassNames){
			Class<?> clazz;

			try {
				clazz = Class.forName(visitorClassName);
			} catch(ClassNotFoundException cnfe){
				throw new IllegalArgumentException("Visitor class " + visitorClassName + " could not be loaded", cnfe);
			}

			Class<? extends Visitor> visitorClazz;

			try {
				visitorClazz = clazz.asSubclass(Visitor.class);
			} catch(ClassCastException cce){
				throw new IllegalArgumentException("Class " + visitorClassName + " is not a subclass of " + Visitor.class.getName(), cce);
			}

			this.visitorClazzes.add(visitorClazz);
		}

		this.validate = modelRegistryConfig.getBoolean("validate");
	}

	/**
	 * Returns the entry set of the backing map.
	 *
	 * <p>
	 * The returned collection is the map's own entry set, not a copy.
	 * </p>
	 *
	 * @return The identifier-to-model entries.
	 */
	public Collection<Map.Entry<String, Model>> entries(){
		return this.models.entrySet();
	}

	/**
	 * Reads a PMML document from the stream and turns it into a {@link Model}.
	 *
	 * <p>
	 * The stream is wrapped so that the number of bytes read and their MD5 hash are
	 * captured while unmarshalling; both are recorded as model properties
	 * ({@link Model#PROPERTY_FILE_SIZE} and {@link Model#PROPERTY_FILE_MD5SUM}).
	 * After unmarshalling, one fresh instance of every configured visitor class is
	 * applied to the PMML object, and then <code>evaluator.verify()</code> is invoked.
	 * </p>
	 *
	 * <p>
	 * This method does not explicitly close <code>is</code>.
	 * The returned model is not added to the registry; use {@link #put(String, Model)} for that.
	 * </p>
	 *
	 * @param is The PMML input.
	 *
	 * @return A new model.
	 *
	 * @throws Exception If parsing, visitor instantiation, visitor application or verification fails.
	 */
	@SuppressWarnings (
		value = {"resource"}
	)
	public Model load(InputStream is) throws Exception {
		CountingInputStream countingIs = new CountingInputStream(is);

		HashingInputStream hashingIs = new HashingInputStream(Hashing.md5(), countingIs);

		ModelEvaluator<?> evaluator = unmarshal(hashingIs, this.validate);

		PMML pmml = evaluator.getPMML();

		for(Class<? extends Visitor> visitorClazz : this.visitorClazzes){
			// Class#newInstance() is deprecated and rethrows checked constructor exceptions undeclared;
			// going through the no-arg Constructor wraps them in InvocationTargetException instead
			Visitor visitor = (visitorClazz.getDeclaredConstructor()).newInstance();

			visitor.applyTo(pmml);
		}

		evaluator.verify();

		Model model = new Model(evaluator);
		model.putProperty(Model.PROPERTY_FILE_SIZE, countingIs.getCount());
		model.putProperty(Model.PROPERTY_FILE_MD5SUM, (hashingIs.hash()).toString());

		return model;
	}

	/**
	 * Writes the PMML document of the model's evaluator to the stream.
	 *
	 * @param model The model to serialize.
	 * @param os The destination.
	 *
	 * @throws JAXBException If marshalling fails.
	 */
	public void store(Model model, OutputStream os) throws JAXBException {
		ModelEvaluator<?> evaluator = model.getEvaluator();

		marshal(evaluator, os);
	}

	/**
	 * Looks up a model without updating its access timestamp.
	 *
	 * @param id The model identifier.
	 *
	 * @return The model, or <code>null</code> if there is no mapping for the identifier.
	 */
	public Model get(String id){
		return get(id, false);
	}

	/**
	 * Looks up a model.
	 *
	 * @param id The model identifier.
	 * @param touch If <code>true</code> and the model exists, sets its
	 * {@link Model#PROPERTY_ACCESSED_TIMESTAMP} property to the current date.
	 *
	 * @return The model, or <code>null</code> if there is no mapping for the identifier.
	 */
	public Model get(String id, boolean touch){
		Model model = this.models.get(id);

		if(model != null && touch){
			model.putProperty(Model.PROPERTY_ACCESSED_TIMESTAMP, new Date());
		}

		return model;
	}

	/**
	 * Registers a model under the identifier, unless the identifier is already taken.
	 *
	 * @param id The model identifier.
	 * @param model The model; must not be <code>null</code>.
	 *
	 * @return <code>true</code> if the model was registered,
	 * <code>false</code> if another model was already mapped to the identifier.
	 */
	public boolean put(String id, Model model){
		Model oldModel = this.models.putIfAbsent(id, Preconditions.checkNotNull(model));

		return (oldModel == null);
	}

	/**
	 * Replaces the model mapped to the identifier, but only if the current mapping is <code>oldModel</code>.
	 *
	 * @param id The model identifier.
	 * @param oldModel The model that is expected to be currently mapped.
	 * @param model The replacement model; must not be <code>null</code>.
	 *
	 * @return <code>true</code> if the mapping was replaced.
	 */
	public boolean replace(String id, Model oldModel, Model model){
		return this.models.replace(id, oldModel, Preconditions.checkNotNull(model));
	}

	/**
	 * Removes the mapping for the identifier, but only if the current mapping is <code>model</code>.
	 *
	 * @param id The model identifier.
	 * @param model The model that is expected to be currently mapped.
	 *
	 * @return <code>true</code> if the mapping was removed.
	 */
	public boolean remove(String id, Model model){
		return this.models.remove(id, model);
	}

	/**
	 * Checks that the identifier is non-null and matches {@link #ID_REGEX} in full.
	 *
	 * @param id The candidate identifier; may be <code>null</code>.
	 *
	 * @return <code>true</code> if the identifier is acceptable.
	 */
	public static boolean validateId(String id){
		return (id != null && (ID_PATTERN.matcher(id)).matches());
	}

	/**
	 * Parses a PMML document and creates an evaluator for it.
	 *
	 * @param is The PMML input.
	 * @param validate If <code>true</code>, the unmarshaller is configured with the schema from {@link JAXBUtil#getSchema()}.
	 */
	private static ModelEvaluator<?> unmarshal(InputStream is, boolean validate) throws IOException, SAXException, JAXBException {
		XMLReader reader = XMLReaderFactory.createXMLReader();

		// Reject documents that carry a DOCTYPE declaration (guards against XXE-style attacks)
		reader.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);

		// NOTE(review): ImportFilter presumably rewrites older PMML schema versions to the current one - confirm against JPMML-Model
		ImportFilter filter = new ImportFilter(reader);

		Source source = new SAXSource(filter, new InputSource(is));

		Unmarshaller unmarshaller = JAXBUtil.createUnmarshaller();
		unmarshaller.setEventHandler(new SimpleValidationEventHandler());

		if(validate){
			Schema schema = JAXBUtil.getSchema();

			unmarshaller.setSchema(schema);
		}

		PMML pmml = (PMML)unmarshaller.unmarshal(source);

		ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();

		return modelEvaluatorFactory.newModelEvaluator(pmml);
	}

	/**
	 * Marshals the evaluator's PMML document to the stream.
	 */
	private static void marshal(ModelEvaluator<?> evaluator, OutputStream os) throws JAXBException {
		PMML pmml = evaluator.getPMML();

		Result result = new StreamResult(os);

		Marshaller marshaller = JAXBUtil.createMarshaller();
		marshaller.marshal(pmml, result);
	}

	/**
	 * Validation event handler that aborts on errors and fatal errors, and continues on anything else (i.e. warnings).
	 */
	private static class SimpleValidationEventHandler implements ValidationEventHandler {

		/**
		 * @return <code>false</code> (stop the current JAXB operation) for {@link ValidationEvent#ERROR}
		 * and {@link ValidationEvent#FATAL_ERROR}, <code>true</code> (continue) otherwise.
		 */
		@Override
		public boolean handleEvent(ValidationEvent event){
			int severity = event.getSeverity();

			switch(severity){
				case ValidationEvent.ERROR:
				case ValidationEvent.FATAL_ERROR:
					return false;
				default:
					return true;
			}
		}
	}

	/**
	 * Regular expression for valid model identifiers:
	 * an ASCII letter or digit, followed by any number of ASCII letters, digits, underscores or hyphens.
	 */
	public static final String ID_REGEX = "[a-zA-Z0-9][a-zA-Z0-9\\_\\-]*";

	/**
	 * Precompiled form of {@link #ID_REGEX}, so that {@link #validateId(String)} does not recompile it on every call.
	 * The type is spelled out in full because <code>java.util.regex.Pattern</code> is not among the file's imports.
	 */
	private static final java.util.regex.Pattern ID_PATTERN = java.util.regex.Pattern.compile(ID_REGEX);
}